{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Distributed training of a DLinear time-series model\n",
    "\n",
    "<div align=\"left\">\n",
    "<a target=\"_blank\" href=\"https://console.anyscale.com/\"><img src=\"https://img.shields.io/badge/🚀%20Run%20on-Anyscale-9hf\"></a>&nbsp;\n",
    "<a href=\"https://github.com/anyscale/e2e-timeseries\" role=\"button\"><img src=\"https://img.shields.io/static/v1?label=&message=View%20On%20GitHub&color=586069&logo=github&labelColor=2f363d\"></a>\n",
    "</div>\n",
    "\n",
    "\n",
    "This tutorial executes a distributed training workload that connects the following steps with heterogeneous compute requirements:\n",
    "\n",
    "* Preprocessing the dataset with Ray Data\n",
    "* Distributed training of a DLinear model with Ray Train\n",
    "\n",
    "Note: This tutorial doesn't including tuning of the model. See Ray Tune for experiment execution and hyperparameter tuning.\n",
    "\n",
    "<img src=\"https://raw.githubusercontent.com/anyscale/e2e-timeseries/master/images/distributed_training.png\" width=800>\n",
    "\n",
    "Before starting, run the setup steps outlined in the README.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Enable Ray Train v2. This is the default in an upcoming release.\n",
    "os.environ[\"RAY_TRAIN_V2_ENABLED\"] = \"1\"\n",
    "# Now it's safe to import from ray.train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable importing from e2e_timeseries module.\n",
    "import sys\n",
    "\n",
    "sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), os.pardir)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import tempfile\n",
    "import time\n",
    "import warnings\n",
    "\n",
    "import numpy as np\n",
    "import ray\n",
    "from ray import train\n",
    "from ray.train import Checkpoint, CheckpointConfig, RunConfig, ScalingConfig, get_dataset_shard\n",
    "from ray.train.torch import TorchTrainer\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch import optim\n",
    "\n",
    "import e2e_timeseries\n",
    "from e2e_timeseries.data_factory import data_provider\n",
    "from e2e_timeseries.metrics import metric\n",
    "from e2e_timeseries.model import DLinear\n",
    "from e2e_timeseries.tools import adjust_learning_rate\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialize the Ray cluster with the `e2e_timeseries` module, so that newly-spawned workers can import from it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ray.init(runtime_env={\"py_modules\": [e2e_timeseries]})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Anatomy of a Ray Train job\n",
    "\n",
    "Ray Train provides the Trainer abstraction, which handles the complexity of distributed training. The Trainer takes a few inputs:\n",
    "\n",
    "- Training function: The Python code that executes on each distributed training worker.\n",
    "- Train configuration: Contains the hyperparameters that the Trainer passes to the training function.\n",
    "- Scaling configuration: Defines the scaling behavior of the job and whether to use accelerators.\n",
    "- Run configuration: Controls checkpointing and specifies storage locations.\n",
    "\n",
    "The Trainer then launches the workers across the Ray Cluster according to the scaling configuration and runs the training function on each worker.\n",
    "\n",
    "<img src=\"https://raw.githubusercontent.com/anyscale/e2e-timeseries/master/images/ray_train_graph.png\" width=800>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The train configuration\n",
    "\n",
    "First, set up the training configuration for the trainable function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = {\n",
    "    # Basic config.\n",
    "    \"train_only\": False,\n",
    "    # Data loader args.\n",
    "    \"num_data_workers\": 10,\n",
    "    # Forecasting task type.\n",
    "    # S: univariate predict univariate\n",
    "    # M: multivariate predict univariate\n",
    "    # MS: multivariate predict multivariate\n",
    "    \"features\": \"S\",\n",
    "    \"target\": \"OT\",  # Target variable name for prediction\n",
    "    # Forecasting task args.\n",
    "    \"seq_len\": 96,\n",
    "    \"label_len\": 48,\n",
    "    \"pred_len\": 96,\n",
    "    # DLinear-specific args.\n",
    "    \"individual\": False,\n",
    "    # Optimization args.\n",
    "    \"num_replicas\": 4,\n",
    "    \"train_epochs\": 10,\n",
    "    \"batch_size\": 32,\n",
    "    \"learning_rate\": 0.005,\n",
    "    \"loss\": \"mse\",\n",
    "    \"lradj\": \"type1\",\n",
    "    \"use_amp\": False,\n",
    "    # Other args.\n",
    "    \"seed\": 42,\n",
    "}\n",
    "\n",
    "# Dataset-specific args.\n",
    "config[\"data\"] = \"ETTh1\"\n",
    "if config[\"features\"] == \"S\":  # S: univariate predict univariate\n",
    "    config[\"enc_in\"] = 1\n",
    "else:  # M or MS\n",
    "    config[\"enc_in\"] = 7  # ETTh1 has 7 features"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuring persistent storage\n",
    "\n",
    "Next, configure the storage that the workers use to store checkpoints and artifacts. The storage needs to be accessible from all workers in the cluster. This storage can be S3, NFS, or another network-attached solution. Anyscale simplifies this process by automatically creating and mounting [shared storage options](https://docs.anyscale.com/configuration/storage/#storage-shared-across-nodes) on every cluster node, ensuring that model artifacts can are readable and writeable consistently across the distributed environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "config[\"checkpoints\"] = \"/mnt/cluster_storage/checkpoints\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that passing large objects such as model weights and datasets through this configuration is an anti-pattern. Doing so can cause high serialization and deserialization overhead. Instead, it's preferred to initialize these objects within the training function. Alternatively, \n",
    "\n",
    "For the purposes of demonstration, enable smoke test mode."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config[\"smoke_test\"] = True\n",
    "if config[\"smoke_test\"]:\n",
    "    print(\"--- RUNNING SMOKE TEST ---\")\n",
    "    config[\"train_epochs\"] = 2\n",
    "    config[\"batch_size\"] = 2\n",
    "    config[\"num_data_workers\"] = 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up a training function\n",
    "\n",
    "The training function holds the model training logic which each distributed training worker executes. The TorchTrainer passes a configuration dictionary as input to the training function. Ray Train provides a few convenience functions for distributed training:\n",
    "\n",
    "- Automatically moving each model replica to the correct device.\n",
    "- Setting up the parallelization strategy (for example, distributed data parallel or fully sharded data parallel).\n",
    "- Setting up PyTorch data loaders for distributed execution, including auto-transfering objects to the correct device.\n",
    "- Reporting metrics and handling distributed checkpointing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_loop_per_worker(config: dict):\n",
    "    \"\"\"Main training loop run on Ray Train workers.\"\"\"\n",
    "\n",
    "    random.seed(config[\"seed\"])\n",
    "    torch.manual_seed(config[\"seed\"])\n",
    "    np.random.seed(config[\"seed\"])\n",
    "\n",
    "    # Automatically determine device based on availability.\n",
    "    device = train.torch.get_device()\n",
    "\n",
    "    def _postprocess_preds_and_targets(raw_pred, batch_y, config):\n",
    "        pred_len = config[\"pred_len\"]\n",
    "        f_dim_start_index = -1 if config[\"features\"] == \"MS\" else 0\n",
    "\n",
    "        # Slice for prediction length first.\n",
    "        outputs_pred_len = raw_pred[:, -pred_len:, :]\n",
    "        batch_y_pred_len = batch_y[:, -pred_len:, :]\n",
    "\n",
    "        # Then slice for features.\n",
    "        final_pred = outputs_pred_len[:, :, f_dim_start_index:]\n",
    "        final_target = batch_y_pred_len[:, :, f_dim_start_index:]\n",
    "\n",
    "        return final_pred, final_target\n",
    "\n",
    "    # === Build Model ===\n",
    "    model = DLinear(config).float()\n",
    "    # Convenience function to move the model to the correct device and set up\n",
    "    # parallel strategy.\n",
    "    model = train.torch.prepare_model(model)\n",
    "\n",
    "    # === Get Data ===\n",
    "    train_ds = get_dataset_shard(\"train\")\n",
    "\n",
    "    # === Optimizer and Criterion ===\n",
    "    model_optim = optim.Adam(model.parameters(), lr=config[\"learning_rate\"])\n",
    "    criterion = nn.MSELoss()\n",
    "\n",
    "    # === AMP Scaler ===\n",
    "    scaler = None\n",
    "    if config[\"use_amp\"]:\n",
    "        scaler = torch.amp.GradScaler(\"cuda\")\n",
    "\n",
    "    # === Training Loop ===\n",
    "    for epoch in range(config[\"train_epochs\"]):\n",
    "        model.train()\n",
    "        train_loss_epoch = []\n",
    "        epoch_start_time = time.time()\n",
    "\n",
    "        # Iterate over Ray Dataset batches. The dataset now yields dicts {'x': numpy_array, 'y': numpy_array}\n",
    "        # iter_torch_batches converts these to Torch tensors and move to device.\n",
    "        for batch in train_ds.iter_torch_batches(batch_size=config[\"batch_size\"], device=device, dtypes=torch.float32):\n",
    "            model_optim.zero_grad()\n",
    "            x = batch[\"x\"]\n",
    "            y = batch[\"y\"]\n",
    "\n",
    "            # Forward pass\n",
    "            if config[\"use_amp\"]:\n",
    "                with torch.amp.autocast(\"cuda\"):\n",
    "                    raw_preds = model(x)\n",
    "                    predictions, targets = _postprocess_preds_and_targets(raw_preds, y, config)\n",
    "                    loss = criterion(predictions, targets)\n",
    "            else:\n",
    "                raw_preds = model(x)\n",
    "                predictions, targets = _postprocess_preds_and_targets(raw_preds, y, config)\n",
    "                loss = criterion(predictions, targets)\n",
    "\n",
    "            train_loss_epoch.append(loss.item())\n",
    "\n",
    "            # Backward pass.\n",
    "            if config[\"use_amp\"]:\n",
    "                scaler.scale(loss).backward()\n",
    "                scaler.step(model_optim)\n",
    "                scaler.update()\n",
    "            else:\n",
    "                loss.backward()\n",
    "                model_optim.step()\n",
    "\n",
    "        # === End of Epoch ===\n",
    "        epoch_train_loss = np.average(train_loss_epoch)\n",
    "        epoch_duration = time.time() - epoch_start_time\n",
    "\n",
    "        results_dict = {\n",
    "            \"epoch\": epoch + 1,\n",
    "            \"train/loss\": epoch_train_loss,\n",
    "            \"epoch_duration_s\": epoch_duration,\n",
    "        }\n",
    "\n",
    "        # === Validation ===\n",
    "        if not config[\"train_only\"]:\n",
    "            val_ds = get_dataset_shard(\"val\")\n",
    "\n",
    "            model.eval()\n",
    "            all_preds = []\n",
    "            all_trues = []\n",
    "            with torch.no_grad():\n",
    "                for batch in val_ds.iter_torch_batches(batch_size=config[\"batch_size\"], device=device, dtypes=torch.float32):\n",
    "                    x, y = batch[\"x\"], batch[\"y\"]\n",
    "\n",
    "                    if config[\"use_amp\"] and torch.cuda.is_available():\n",
    "                        with torch.amp.autocast(\"cuda\"):\n",
    "                            raw_preds = model(x)\n",
    "                    else:\n",
    "                        raw_preds = model(x)\n",
    "\n",
    "                    predictions, targets = _postprocess_preds_and_targets(raw_preds, y, config)\n",
    "\n",
    "                    all_preds.append(predictions.detach().cpu().numpy())\n",
    "                    all_trues.append(targets.detach().cpu().numpy())\n",
    "\n",
    "            all_preds = np.concatenate(all_preds, axis=0)\n",
    "            all_trues = np.concatenate(all_trues, axis=0)\n",
    "\n",
    "            mae, mse, rmse, mape, mspe, rse = metric(all_preds, all_trues)\n",
    "\n",
    "            results_dict[\"val/loss\"] = mse\n",
    "            results_dict[\"val/mae\"] = mae\n",
    "            results_dict[\"val/rmse\"] = rmse\n",
    "            results_dict[\"val/mape\"] = mape\n",
    "            results_dict[\"val/mspe\"] = mspe\n",
    "            results_dict[\"val/rse\"] = rse\n",
    "\n",
    "            print(f\"Epoch {epoch + 1}: Train Loss: {epoch_train_loss:.7f}, Val Loss: {mse:.7f}, Val MSE: {mse:.7f} (Duration: {epoch_duration:.2f}s)\")\n",
    "\n",
    "        # === Reporting and Checkpointing ===\n",
    "        if train.get_context().get_world_rank() == 0:\n",
    "            with tempfile.TemporaryDirectory() as temp_checkpoint_dir:\n",
    "                torch.save(\n",
    "                    {\n",
    "                        \"epoch\": epoch,\n",
    "                        \"model_state_dict\": model.module.state_dict() if hasattr(model, \"module\") else model.state_dict(),\n",
    "                        \"optimizer_state_dict\": model_optim.state_dict(),\n",
    "                        \"train_args\": config,\n",
    "                    },\n",
    "                    os.path.join(temp_checkpoint_dir, \"checkpoint.pt\"),\n",
    "                )\n",
    "                checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)\n",
    "                train.report(metrics=results_dict, checkpoint=checkpoint)\n",
    "        else:\n",
    "            train.report(metrics=results_dict, checkpoint=None)\n",
    "\n",
    "        adjust_learning_rate(model_optim, epoch + 1, config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> **Ray Train Benefits:**\n",
    "> \n",
    "> **Multi-node orchestration**: Automatically handles multi-node, multi-GPU setup without manual SSH or hostfile configurations\n",
    "> \n",
    "> **Built-in fault tolerance**: Supports automatic retry of failed workers and can continue from the last checkpoint\n",
    "> \n",
    "> **Flexible training strategies**: Supports various parallelism strategies beyond just data parallel training\n",
    "> \n",
    "> **Heterogeneous cluster support**: Define per-worker resource requirements and run on mixed hardware\n",
    "> \n",
    "> Ray Train integrates with popular frameworks like PyTorch, TensorFlow, XGBoost, and more. For enterprise needs, [RayTurbo Train](https://docs.anyscale.com/rayturbo/rayturbo-train) offers additional features like elastic training, advanced monitoring, and performance optimization.\n",
    "\n",
    "<img src=\"https://raw.githubusercontent.com/anyscale/e2e-timeseries/master/images/train_integrations.png\" width=800>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up the scaling config\n",
    "\n",
    "Next, set up the scaling configuration. This example assigns one model replica per GPU in the cluster."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scaling_config = ScalingConfig(num_workers=config[\"num_replicas\"], use_gpu=True, resources_per_worker={\"GPU\": 1})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Checkpointing configuration\n",
    "\n",
    "Checkpointing enables you to resume training from the last checkpoint in case of interruptions or failures. Checkpointing is particularly useful for long-running training sessions. [`CheckpointConfig`](https://docs.ray.io/en/latest/train/api/doc/ray.train.CheckpointConfig.html) makes it easy to customize the checkpointing policy.\n",
    "\n",
    "This example demonstrates how to keep a maximum of two model checkpoints based on their minimum validation loss score.\n",
    "\n",
    "Note: Once you enable checkpointing, you can follow [this guide](https://docs.ray.io/en/latest/train/user-guides/fault-tolerance.html) to enable fault tolerance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adjust run name during smoke tests.\n",
    "run_name_prefix = \"SmokeTest_\" if config[\"smoke_test\"] else \"\"\n",
    "run_name = f\"{run_name_prefix}DLinear_{config['data']}_{config['features']}_{config['target']}_{time.strftime('%Y%m%d_%H%M%S')}\"\n",
    "\n",
    "run_config = RunConfig(\n",
    "    storage_path=config[\"checkpoints\"],\n",
    "    name=run_name,\n",
    "    checkpoint_config=CheckpointConfig(num_to_keep=2, checkpoint_score_attribute=\"val/loss\", checkpoint_score_order=\"min\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Datasets\n",
    "\n",
    "Ray Data is a library that enables distributed and streaming pre-processing of data. It's possible to convert an existing PyTorch Dataset to a Ray Dataset using `ray_ds = ray.data.from_torch(pytorch_ds)`.\n",
    "\n",
    "To distribute the Ray Dataset to each training worker, pass the datasets as a dictionary to the `datasets` parameter. Later, calling [`get_dataset_shard()`](https://docs.ray.io/en/master/train/api/doc/ray.train.get_dataset_shard.html#ray.train.get_dataset_shard) inside the training function automatically fetches a shard of the dataset assigned to that worker.\n",
    "\n",
    "This tutorial uses the [Electricity Transformer dataset](https://github.com/zhouhaoyi/ETDataset) (ETDataset), which measures the oil temperature of dozens of electrical stations in China over two years."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets = {\"train\": data_provider(config, flag=\"train\")}\n",
    "if not config[\"train_only\"]:\n",
    "    datasets[\"val\"] = data_provider(config, flag=\"val\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because Ray Data lazily evaluates Ray Datasets, use `show(1)` to materialize a sample of the dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets[\"train\"].show(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial, the training objective is to predict future oil temperatures `y` given a window of past oil temperatures `x`.\n",
    "\n",
    "Executing `.show(1)` streams a single record through the pre-processing pipeline, standardizing the temperature column with zero-centered and unit-normalized values.\n",
    "\n",
    "Next, combine all the inputs to initialize the `TorchTrainer`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = TorchTrainer(\n",
    "    train_loop_per_worker=train_loop_per_worker,\n",
    "    train_loop_config=config,\n",
    "    scaling_config=scaling_config,\n",
    "    run_config=run_config,\n",
    "    datasets=datasets,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, execute training using the `.fit()` method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# === Run Training ===\n",
    "print(\"Starting Ray Train job...\")\n",
    "result = trainer.fit()\n",
    "print(\"Training finished!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Observe that at the beginning of the training job, Ray immediately requests four GPU nodes defined in the `ScalingConfig`. Because you enabled \"Auto-select worker nodes,\" Anyscale automatically provisions any missing compute.\n",
    "\n",
    "You can monitor the scaling behavior and cluster resource utilization on the Ray Dashboard:\n",
    "\n",
    "<img src=\"https://raw.githubusercontent.com/anyscale/e2e-timeseries/master/images/train_metrics.png\" width=800>\n",
    "\n",
    "The Ray Train job returns a `ray.train.Result` object, which contains important properties such as metrics, checkpoint info, and error details:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "metrics = result.metrics\n",
    "metrics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The metrics should look something like the following:\n",
    "\n",
    "```python\n",
    "{'epoch': 2,\n",
    " 'train/loss': 0.33263104565833745,\n",
    " 'epoch_duration_s': 0.9015529155731201,\n",
    " 'val/loss': 0.296540230512619,\n",
    " 'val/mae': 0.4813770353794098,\n",
    " 'val/rmse': 0.544555075738551,\n",
    " 'val/mape': 9.20688533782959,\n",
    " 'val/mspe': 2256.628662109375,\n",
    " 'val/rse': 1.3782594203948975}\n",
    "```\n",
    "\n",
    "Now that the model has completed training, find the checkpoint with the lowest loss in the [`Result`](https://docs.ray.io/en/master/train/api/doc/ray.train.Result.html) object."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# === Post-Training ===\n",
    "if result.best_checkpoints:\n",
    "    best_checkpoint_path = None\n",
    "    if not config[\"train_only\"] and \"val/loss\" in result.metrics_dataframe:\n",
    "        best_checkpoint = result.get_best_checkpoint(metric=\"val/loss\", mode=\"min\")\n",
    "        if best_checkpoint:\n",
    "            best_checkpoint_path = best_checkpoint.path\n",
    "    elif \"train/loss\" in result.metrics_dataframe:  # Fallback or if train_only\n",
    "        best_checkpoint = result.get_best_checkpoint(metric=\"train/loss\", mode=\"min\")\n",
    "        if best_checkpoint:\n",
    "            best_checkpoint_path = best_checkpoint.path\n",
    "\n",
    "    if best_checkpoint_path:\n",
    "        print(\"Best checkpoint found:\")\n",
    "        print(f\"  Directory: {best_checkpoint_path}\")\n",
    "\n",
    "        best_checkpoint_metadata_fpath = os.path.join(\n",
    "            \"/mnt/cluster_storage/checkpoints\", \"best_checkpoint_path.txt\"\n",
    "        )\n",
    "\n",
    "        with open(best_checkpoint_metadata_fpath, \"w\") as f:\n",
    "            # Store the best checkpoint path in a file for later use\n",
    "            f.write(f\"{best_checkpoint_path}/checkpoint.pt\")\n",
    "            print(\"Train run metadata saved.\")\n",
    "    else:\n",
    "        print(\"Could not retrieve the best checkpoint based on available metrics.\")\n",
    "else:\n",
    "    print(\"No checkpoints were saved during training.\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
